{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 678.5208740234375\n",
      "1 678.3446044921875\n",
      "2 678.1679077148438\n",
      "3 677.9906005859375\n",
      "4 677.8140258789062\n",
      "5 677.6375732421875\n",
      "6 677.4603881835938\n",
      "7 677.2842407226562\n",
      "8 677.1077270507812\n",
      "9 676.9318237304688\n",
      "10 676.7547607421875\n",
      "11 676.5780639648438\n",
      "12 676.402099609375\n",
      "13 676.22607421875\n",
      "14 676.050537109375\n",
      "15 675.8742065429688\n",
      "16 675.697998046875\n",
      "17 675.5228881835938\n",
      "18 675.3466186523438\n",
      "19 675.171142578125\n",
      "20 674.9957885742188\n",
      "21 674.8201293945312\n",
      "22 674.6449584960938\n",
      "23 674.4696655273438\n",
      "24 674.293701171875\n",
      "25 674.1188354492188\n",
      "26 673.9436645507812\n",
      "27 673.7688598632812\n",
      "28 673.5933837890625\n",
      "29 673.4183349609375\n",
      "30 673.2432250976562\n",
      "31 673.0677490234375\n",
      "32 672.8930053710938\n",
      "33 672.718505859375\n",
      "34 672.5432739257812\n",
      "35 672.36865234375\n",
      "36 672.1939697265625\n",
      "37 672.01953125\n",
      "38 671.845458984375\n",
      "39 671.6705322265625\n",
      "40 671.4959716796875\n",
      "41 671.3214111328125\n",
      "42 671.1471557617188\n",
      "43 670.9730224609375\n",
      "44 670.7986450195312\n",
      "45 670.6248779296875\n",
      "46 670.4505615234375\n",
      "47 670.2765502929688\n",
      "48 670.1029052734375\n",
      "49 669.9291381835938\n",
      "50 669.7549438476562\n",
      "51 669.5812377929688\n",
      "52 669.4072265625\n",
      "53 669.2335205078125\n",
      "54 669.0597534179688\n",
      "55 668.8862915039062\n",
      "56 668.7130126953125\n",
      "57 668.5393676757812\n",
      "58 668.365478515625\n",
      "59 668.1922607421875\n",
      "60 668.0196533203125\n",
      "61 667.8458862304688\n",
      "62 667.6719360351562\n",
      "63 667.4987182617188\n",
      "64 667.3253784179688\n",
      "65 667.1525268554688\n",
      "66 666.9796752929688\n",
      "67 666.8062744140625\n",
      "68 666.6332397460938\n",
      "69 666.4602661132812\n",
      "70 666.2880249023438\n",
      "71 666.1153564453125\n",
      "72 665.9419555664062\n",
      "73 665.7698974609375\n",
      "74 665.5972290039062\n",
      "75 665.4249267578125\n",
      "76 665.2523803710938\n",
      "77 665.0807495117188\n",
      "78 664.90869140625\n",
      "79 664.7369384765625\n",
      "80 664.5653686523438\n",
      "81 664.3935546875\n",
      "82 664.2222290039062\n",
      "83 664.0504760742188\n",
      "84 663.8796997070312\n",
      "85 663.708251953125\n",
      "86 663.5371704101562\n",
      "87 663.3655395507812\n",
      "88 663.1943969726562\n",
      "89 663.0235595703125\n",
      "90 662.8519287109375\n",
      "91 662.68115234375\n",
      "92 662.5101318359375\n",
      "93 662.3391723632812\n",
      "94 662.167724609375\n",
      "95 661.9973754882812\n",
      "96 661.8262939453125\n",
      "97 661.65576171875\n",
      "98 661.4849853515625\n",
      "99 661.3147583007812\n",
      "100 661.1438598632812\n",
      "101 660.9734497070312\n",
      "102 660.8028564453125\n",
      "103 660.6320190429688\n",
      "104 660.4620361328125\n",
      "105 660.2922973632812\n",
      "106 660.1224365234375\n",
      "107 659.9530029296875\n",
      "108 659.7825927734375\n",
      "109 659.612548828125\n",
      "110 659.4423828125\n",
      "111 659.2726440429688\n",
      "112 659.102294921875\n",
      "113 658.93212890625\n",
      "114 658.7626342773438\n",
      "115 658.5927124023438\n",
      "116 658.4229125976562\n",
      "117 658.2532348632812\n",
      "118 658.0839233398438\n",
      "119 657.9139404296875\n",
      "120 657.7446899414062\n",
      "121 657.5745849609375\n",
      "122 657.4056396484375\n",
      "123 657.2360229492188\n",
      "124 657.0665283203125\n",
      "125 656.8975219726562\n",
      "126 656.7283325195312\n",
      "127 656.5584716796875\n",
      "128 656.38916015625\n",
      "129 656.219970703125\n",
      "130 656.0503540039062\n",
      "131 655.8818969726562\n",
      "132 655.7127685546875\n",
      "133 655.5438842773438\n",
      "134 655.3749389648438\n",
      "135 655.2053833007812\n",
      "136 655.0366821289062\n",
      "137 654.8676147460938\n",
      "138 654.6989135742188\n",
      "139 654.529541015625\n",
      "140 654.3614501953125\n",
      "141 654.1927490234375\n",
      "142 654.0238037109375\n",
      "143 653.85546875\n",
      "144 653.6865234375\n",
      "145 653.5183715820312\n",
      "146 653.3500366210938\n",
      "147 653.1820068359375\n",
      "148 653.0127563476562\n",
      "149 652.8450927734375\n",
      "150 652.6767578125\n",
      "151 652.50830078125\n",
      "152 652.3396606445312\n",
      "153 652.1715698242188\n",
      "154 652.0029907226562\n",
      "155 651.8350830078125\n",
      "156 651.6668090820312\n",
      "157 651.4991455078125\n",
      "158 651.3314819335938\n",
      "159 651.1629028320312\n",
      "160 650.9957885742188\n",
      "161 650.8275146484375\n",
      "162 650.6602783203125\n",
      "163 650.4924926757812\n",
      "164 650.324462890625\n",
      "165 650.1570434570312\n",
      "166 649.9893798828125\n",
      "167 649.8214111328125\n",
      "168 649.6541137695312\n",
      "169 649.4868774414062\n",
      "170 649.3196411132812\n",
      "171 649.152099609375\n",
      "172 648.9846801757812\n",
      "173 648.8172607421875\n",
      "174 648.6503295898438\n",
      "175 648.4827880859375\n",
      "176 648.3154296875\n",
      "177 648.1484985351562\n",
      "178 647.981201171875\n",
      "179 647.8145751953125\n",
      "180 647.6472778320312\n",
      "181 647.480224609375\n",
      "182 647.3131103515625\n",
      "183 647.1466674804688\n",
      "184 646.97998046875\n",
      "185 646.8133544921875\n",
      "186 646.6460571289062\n",
      "187 646.4794311523438\n",
      "188 646.3126831054688\n",
      "189 646.1461791992188\n",
      "190 645.9801025390625\n",
      "191 645.813232421875\n",
      "192 645.6470947265625\n",
      "193 645.4805297851562\n",
      "194 645.3139038085938\n",
      "195 645.1473999023438\n",
      "196 644.9813232421875\n",
      "197 644.8147583007812\n",
      "198 644.649169921875\n",
      "199 644.4830932617188\n",
      "200 644.3173217773438\n",
      "201 644.1513061523438\n",
      "202 643.9851684570312\n",
      "203 643.820068359375\n",
      "204 643.6536865234375\n",
      "205 643.4884033203125\n",
      "206 643.3225708007812\n",
      "207 643.1571044921875\n",
      "208 642.9920043945312\n",
      "209 642.82666015625\n",
      "210 642.6612548828125\n",
      "211 642.49560546875\n",
      "212 642.33056640625\n",
      "213 642.165771484375\n",
      "214 642.0003051757812\n",
      "215 641.8352661132812\n",
      "216 641.6705932617188\n",
      "217 641.5056762695312\n",
      "218 641.340576171875\n",
      "219 641.1759643554688\n",
      "220 641.0109252929688\n",
      "221 640.8458862304688\n",
      "222 640.681884765625\n",
      "223 640.5170288085938\n",
      "224 640.3519287109375\n",
      "225 640.1876831054688\n",
      "226 640.0231323242188\n",
      "227 639.8587646484375\n",
      "228 639.6945190429688\n",
      "229 639.530029296875\n",
      "230 639.3653564453125\n",
      "231 639.2012939453125\n",
      "232 639.0376586914062\n",
      "233 638.8738403320312\n",
      "234 638.7094116210938\n",
      "235 638.5454711914062\n",
      "236 638.38134765625\n",
      "237 638.218017578125\n",
      "238 638.053955078125\n",
      "239 637.8898315429688\n",
      "240 637.7268676757812\n",
      "241 637.562255859375\n",
      "242 637.3992919921875\n",
      "243 637.2353515625\n",
      "244 637.0723266601562\n",
      "245 636.9085693359375\n",
      "246 636.7447509765625\n",
      "247 636.5810546875\n",
      "248 636.4181518554688\n",
      "249 636.2550659179688\n",
      "250 636.0913696289062\n",
      "251 635.9288330078125\n",
      "252 635.7658081054688\n",
      "253 635.6022338867188\n",
      "254 635.439208984375\n",
      "255 635.2763671875\n",
      "256 635.1129760742188\n",
      "257 634.9498901367188\n",
      "258 634.7871704101562\n",
      "259 634.6248779296875\n",
      "260 634.4625244140625\n",
      "261 634.2998046875\n",
      "262 634.137451171875\n",
      "263 633.974609375\n",
      "264 633.8121948242188\n",
      "265 633.6492309570312\n",
      "266 633.4867553710938\n",
      "267 633.3247680664062\n",
      "268 633.1624755859375\n",
      "269 632.9998168945312\n",
      "270 632.8375244140625\n",
      "271 632.6759643554688\n",
      "272 632.5137939453125\n",
      "273 632.3522338867188\n",
      "274 632.1904296875\n",
      "275 632.0286865234375\n",
      "276 631.8667602539062\n",
      "277 631.705078125\n",
      "278 631.5438232421875\n",
      "279 631.3812866210938\n",
      "280 631.220458984375\n",
      "281 631.0581665039062\n",
      "282 630.8973388671875\n",
      "283 630.7360229492188\n",
      "284 630.5750732421875\n",
      "285 630.4136352539062\n",
      "286 630.2523193359375\n",
      "287 630.091064453125\n",
      "288 629.929931640625\n",
      "289 629.7691040039062\n",
      "290 629.6082153320312\n",
      "291 629.4476318359375\n",
      "292 629.2863159179688\n",
      "293 629.1256713867188\n",
      "294 628.9652099609375\n",
      "295 628.8046264648438\n",
      "296 628.6442260742188\n",
      "297 628.4837646484375\n",
      "298 628.322509765625\n",
      "299 628.1622924804688\n",
      "300 628.0015869140625\n",
      "301 627.841064453125\n",
      "302 627.680908203125\n",
      "303 627.5208740234375\n",
      "304 627.3609008789062\n",
      "305 627.2008056640625\n",
      "306 627.0407104492188\n",
      "307 626.8807983398438\n",
      "308 626.7208862304688\n",
      "309 626.5613403320312\n",
      "310 626.4013061523438\n",
      "311 626.2418212890625\n",
      "312 626.081787109375\n",
      "313 625.922607421875\n",
      "314 625.7628784179688\n",
      "315 625.6031494140625\n",
      "316 625.4443969726562\n",
      "317 625.2847900390625\n",
      "318 625.1251220703125\n",
      "319 624.9658813476562\n",
      "320 624.8065185546875\n",
      "321 624.6475830078125\n",
      "322 624.4887084960938\n",
      "323 624.3292236328125\n",
      "324 624.1704711914062\n",
      "325 624.01171875\n",
      "326 623.8527221679688\n",
      "327 623.694091796875\n",
      "328 623.53466796875\n",
      "329 623.3760986328125\n",
      "330 623.2173461914062\n",
      "331 623.0587768554688\n",
      "332 622.900146484375\n",
      "333 622.741455078125\n",
      "334 622.5830078125\n",
      "335 622.4242553710938\n",
      "336 622.2654418945312\n",
      "337 622.1073608398438\n",
      "338 621.9487915039062\n",
      "339 621.7904663085938\n",
      "340 621.6322631835938\n",
      "341 621.4736938476562\n",
      "342 621.3158569335938\n",
      "343 621.1577758789062\n",
      "344 620.99951171875\n",
      "345 620.8412475585938\n",
      "346 620.6828002929688\n",
      "347 620.5253295898438\n",
      "348 620.3673706054688\n",
      "349 620.2095336914062\n",
      "350 620.0513305664062\n",
      "351 619.8944091796875\n",
      "352 619.7359619140625\n",
      "353 619.5782470703125\n",
      "354 619.420654296875\n",
      "355 619.262939453125\n",
      "356 619.1051635742188\n",
      "357 618.9479370117188\n",
      "358 618.78955078125\n",
      "359 618.63330078125\n",
      "360 618.4744873046875\n",
      "361 618.318115234375\n",
      "362 618.1609497070312\n",
      "363 618.00341796875\n",
      "364 617.8462524414062\n",
      "365 617.6892700195312\n",
      "366 617.5322265625\n",
      "367 617.3751831054688\n",
      "368 617.2177734375\n",
      "369 617.060791015625\n",
      "370 616.90380859375\n",
      "371 616.746826171875\n",
      "372 616.5899658203125\n",
      "373 616.4335327148438\n",
      "374 616.2767944335938\n",
      "375 616.1198120117188\n",
      "376 615.9635620117188\n",
      "377 615.806884765625\n",
      "378 615.6504516601562\n",
      "379 615.4937133789062\n",
      "380 615.3368530273438\n",
      "381 615.1809692382812\n",
      "382 615.0244750976562\n",
      "383 614.8679809570312\n",
      "384 614.7120361328125\n",
      "385 614.5556030273438\n",
      "386 614.399169921875\n",
      "387 614.242431640625\n",
      "388 614.08642578125\n",
      "389 613.9308471679688\n",
      "390 613.7750244140625\n",
      "391 613.619384765625\n",
      "392 613.4638061523438\n",
      "393 613.3080444335938\n",
      "394 613.1517333984375\n",
      "395 612.99658203125\n",
      "396 612.84130859375\n",
      "397 612.6849975585938\n",
      "398 612.5296630859375\n",
      "399 612.3743896484375\n",
      "400 612.2185668945312\n",
      "401 612.0631713867188\n",
      "402 611.9082641601562\n",
      "403 611.7529296875\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "404 611.5977783203125\n",
      "405 611.4426879882812\n",
      "406 611.287841796875\n",
      "407 611.132568359375\n",
      "408 610.9777221679688\n",
      "409 610.8225708007812\n",
      "410 610.6676635742188\n",
      "411 610.5131225585938\n",
      "412 610.3583374023438\n",
      "413 610.2035522460938\n",
      "414 610.0487060546875\n",
      "415 609.89404296875\n",
      "416 609.7394409179688\n",
      "417 609.5853271484375\n",
      "418 609.43017578125\n",
      "419 609.2758178710938\n",
      "420 609.1214599609375\n",
      "421 608.9669799804688\n",
      "422 608.8129272460938\n",
      "423 608.6581420898438\n",
      "424 608.5040893554688\n",
      "425 608.35009765625\n",
      "426 608.1954345703125\n",
      "427 608.0413818359375\n",
      "428 607.8877563476562\n",
      "429 607.7330322265625\n",
      "430 607.5791015625\n",
      "431 607.4253540039062\n",
      "432 607.2716674804688\n",
      "433 607.1177368164062\n",
      "434 606.96337890625\n",
      "435 606.8092651367188\n",
      "436 606.655517578125\n",
      "437 606.5015869140625\n",
      "438 606.3477172851562\n",
      "439 606.1937255859375\n",
      "440 606.0402221679688\n",
      "441 605.8863525390625\n",
      "442 605.732177734375\n",
      "443 605.5790405273438\n",
      "444 605.4248657226562\n",
      "445 605.2711181640625\n",
      "446 605.1172485351562\n",
      "447 604.9636840820312\n",
      "448 604.8099365234375\n",
      "449 604.6566772460938\n",
      "450 604.5028686523438\n",
      "451 604.3494873046875\n",
      "452 604.1962890625\n",
      "453 604.0423583984375\n",
      "454 603.8890991210938\n",
      "455 603.7354736328125\n",
      "456 603.5830078125\n",
      "457 603.4290771484375\n",
      "458 603.2763671875\n",
      "459 603.1234130859375\n",
      "460 602.9699096679688\n",
      "461 602.8170166015625\n",
      "462 602.6638793945312\n",
      "463 602.510986328125\n",
      "464 602.3584594726562\n",
      "465 602.2050170898438\n",
      "466 602.0527954101562\n",
      "467 601.8995971679688\n",
      "468 601.747314453125\n",
      "469 601.5946655273438\n",
      "470 601.443359375\n",
      "471 601.290771484375\n",
      "472 601.138427734375\n",
      "473 600.9863891601562\n",
      "474 600.8350219726562\n",
      "475 600.682861328125\n",
      "476 600.5305786132812\n",
      "477 600.3782958984375\n",
      "478 600.226318359375\n",
      "479 600.0741577148438\n",
      "480 599.9218139648438\n",
      "481 599.7705688476562\n",
      "482 599.6192626953125\n",
      "483 599.4675903320312\n",
      "484 599.3159790039062\n",
      "485 599.1647338867188\n",
      "486 599.0134887695312\n",
      "487 598.8618774414062\n",
      "488 598.7102661132812\n",
      "489 598.5592041015625\n",
      "490 598.4073486328125\n",
      "491 598.2564697265625\n",
      "492 598.1055297851562\n",
      "493 597.954345703125\n",
      "494 597.8040771484375\n",
      "495 597.6525268554688\n",
      "496 597.5016479492188\n",
      "497 597.35009765625\n",
      "498 597.1988525390625\n",
      "499 597.0482177734375\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cpu')\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "learning_rate = 1e-6\n",
    "\n",
    "x = torch.randn(N, D_in, device=device)\n",
    "y = torch.randn(N, D_out, device=device)\n",
    "\n",
    "model = torch.nn.Sequential(\n",
    "            torch.nn.Linear(D_in, H),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(H, D_out),\n",
    "        ).to(device)\n",
    "\n",
    "loss_fn = torch.nn.MSELoss(size_average=False)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "for t in range(500):\n",
    "    y_pred = model(x)\n",
    "    loss = loss_fn(y_pred, y)\n",
    "    print(t, loss.item())\n",
    "    \n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
